#!/usr/bin/env python


import rospy
import numpy as np
import math
import os
import cv2
from enum import Enum
from std_msgs.msg import UInt8
from sensor_msgs.msg import Image, CompressedImage
from cv_bridge import CvBridge, CvBridgeError

from rospkg import RosPack

# Default thresholds shared by the per-logo detectors (detect_chess_flag uses
# its own, smaller values instead of these).
# A frame is only tested against a logo when strictly more than this many
# ratio-test matches were found for it.
MIN_MATCH_COUNT = 9
# A logo id is published only when the value returned by fnCalcMSE for the
# matched keypoint coordinates is strictly below this limit.
MIN_MSE_DECISION = 70000

class DetectLogo(object):
    """ROS node that looks for known logos in camera frames with SIFT + FLANN.

    Frames arrive on '/camera/rgb/image_raw'. For every processed frame the
    SIFT descriptors of the frame are matched against the descriptors of each
    logo template. When a logo passes the thresholds its id (the value of the
    matching ``self.Logos`` member) is published on '/detect/logo', and a
    visualisation image is published on '/detect/image_output'.
    """

    # Ratio-test factor: the best match must be clearly better than the
    # second best one to be kept.
    RATIO_TEST = 0.7
    # cv2.findHomography cannot work with fewer than four correspondences.
    MIN_HOMOGRAPHY_POINTS = 4
    # One frame out of this many is processed, the others are dropped.
    FRAME_STRIDE = 3

    # (logo key, legacy image attribute name, template file name).
    # The key is also the member name in self.Logos and the suffix of the
    # legacy 'kp_<key>' / 'des_<key>' attributes.
    _TEMPLATE_FILES = (
        ('chess', 'img_chess_flag', 'chess_flag.png'),
        ('cocacola', 'img_cocacola', 'cocacola.png'),
        ('heineken', 'img_heineken', 'heineken.jpg'),
        ('martini', 'img_martini', 'martini.png'),
        ('pirelli', 'img_pirelli', 'pirelli.png'),
        ('redbull', 'img_redbull', 'redbull.jpg'),
        ('rolex', 'img_rolex', 'rolex.png'),
        ('seiko', 'img_seiko', 'seiko.png'),
    )

    def __init__(self):
        """Load the templates, then create the publishers and the subscriber."""
        self.counter = 0
        self.cvBridge = CvBridge()
        self.Logos = Enum('Logos', 'cocacola rolex pirelli seiko heineken martini redbull chess')

        self.fnPreproc()

        self.pub_detected_logos = rospy.Publisher('/detect/image_output', Image, queue_size = 1)
        self.pub_logo = rospy.Publisher('/detect/logo', UInt8, queue_size=1)

        # Subscribe last: the callback uses every attribute set above, so it
        # must not be able to fire while the object is only half initialised.
        self.sub_image_original = rospy.Subscriber('/camera/rgb/image_raw', Image, self.cbFindImage, queue_size = 1)

    def fnPreproc(self):
        """Create the SIFT detector and FLANN matcher and load all templates.

        Each template is read as a grayscale image and its keypoints and
        descriptors are computed once. The results are stored in
        ``self._templates[key] = (image, keypoints, descriptors)`` and, as
        before, in the ``img_*`` / ``kp_*`` / ``des_*`` attributes.

        Raises:
            IOError: a template file could not be read.
            ValueError: a template yields no SIFT descriptors.
        """
        # Newer OpenCV builds expose SIFT in the main module; older
        # opencv-contrib builds only have it under xfeatures2d.
        if hasattr(cv2, 'SIFT_create'):
            self.sift = cv2.SIFT_create()
        else:
            self.sift = cv2.xfeatures2d.SIFT_create()

        dir_path = RosPack().get_path('turtlebot3_autorace_simulation') + '/file/detect_logos/'
        # Kept for backward compatibility: the node has always switched its
        # working directory to the template folder.
        os.chdir(dir_path)

        self._templates = {}
        for key, img_attr, filename in self._TEMPLATE_FILES:
            image = cv2.imread(dir_path + filename, 0)
            if image is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError('Cannot read logo template: ' + dir_path + filename)

            keypoints, descriptors = self.sift.detectAndCompute(image, None)
            if descriptors is None:
                raise ValueError('No SIFT features in logo template: ' + dir_path + filename)

            self._templates[key] = (image, keypoints, descriptors)
            setattr(self, img_attr, image)
            setattr(self, 'kp_' + key, keypoints)
            setattr(self, 'des_' + key, descriptors)

        # Setting algorithm params
        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)
        search_params = dict(checks = 50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

    def fnCalcMSE(self, arr1, arr2):
        """Return the summed squared difference divided by shape[0] * shape[1].

        Only the first two dimensions of ``arr1`` enter the divisor. For the
        (N, 1, 2) point arrays built by the detectors this is therefore the
        squared x offset plus the squared y offset, averaged over the N
        matched points.
        """
        squared_diff = (arr1 - arr2) ** 2
        total = np.sum(squared_diff)
        num_all = arr1.shape[0] * arr1.shape[1]
        return total / num_all

    def cbFindImage(self, image_msg):
        """Image callback: run all detectors on every third frame."""
        # Frames are dropped because of the processing cost: after a processed
        # frame the counter restarts at 1, so the next two frames are skipped.
        # This is up to your computer's operating power.
        if self.counter % self.FRAME_STRIDE != 0:
            self.counter += 1
            return
        self.counter = 1

        try:
            cv_image_input = self.cvBridge.imgmsg_to_cv2(image_msg, "bgr8")
        except CvBridgeError as err:
            # A frame that cannot be converted is skipped instead of raising
            # inside the subscriber callback.
            rospy.logerr('detect_logo: cannot convert image message: %s' % err)
            return

        # find the keypoints and descriptors with SIFT
        kp_input, des_input = self.sift.detectAndCompute(cv_image_input, None)

        # The call order fixes the order in which ids reach '/detect/logo'.
        chess = self.detect_chess_flag(kp_input, des_input)
        cocacola = self.detect_cocacola(kp_input, des_input, cv_image_input)
        martini = self.detect_martini(kp_input, des_input)
        pirelli = self.detect_pirelli(kp_input, des_input, cv_image_input)
        redbull = self.detect_redbull(kp_input, des_input)
        rolex = self.detect_rolex(kp_input, des_input, cv_image_input)
        seiko = self.detect_seiko(kp_input, des_input)
        heineken = self.detect_heineken(kp_input, des_input)
        self.draw_matches(cv_image_input, kp_input, chess, cocacola, martini, pirelli,
                          redbull, rolex, seiko, heineken)

    def draw_matches(self, cv_image_input, kp_input, chess, cocacola, martini, pirelli, redbull, rolex, seiko, heineken):
        """Publish the frame, with the matches of the first matched logo drawn.

        Every logo argument is the ``(matchesMask, good_matches)`` pair
        returned by the corresponding ``detect_*`` method. The logos are tried
        in a fixed priority order and only the first one whose mask is not
        None is drawn. If there is none, the unmodified frame is published.
        """
        by_priority = (
            ('chess', chess),
            ('cocacola', cocacola),
            ('martini', martini),
            ('rolex', rolex),
            ('pirelli', pirelli),
            ('redbull', redbull),
            ('heineken', heineken),
            ('seiko', seiko),
        )

        output = cv_image_input
        for key, result in by_priority:
            if result[0] is None:
                continue
            # Image and keypoints must belong to the same template, otherwise
            # the match indices point at the wrong (or non-existent) keypoints.
            img_logo, kp_logo, _ = self._templates[key]
            output = cv2.drawMatches(cv_image_input, kp_input, img_logo, kp_logo,
                                     result[1], None,
                                     matchColor = (0,0,255),
                                     singlePointColor = None,
                                     matchesMask = result[0], # draw only inliers
                                     flags = 2)
            break

        self.pub_detected_logos.publish(self.cvBridge.cv2_to_imgmsg(output, "bgr8"))

    def _match_logo(self, key, kp_input, des_input, min_match_count, max_mse):
        """Match the frame against one template; publish its id on success.

        Args:
            key: template key, also the member name in ``self.Logos``.
            kp_input: SIFT keypoints of the camera frame.
            des_input: SIFT descriptors of the camera frame, or None when the
                frame has no features.
            min_match_count: strictly more than this many good matches are
                required (and never fewer than MIN_HOMOGRAPHY_POINTS).
            max_mse: the id is published only if fnCalcMSE of the matched
                point coordinates is strictly below this value.

        Returns:
            ``(matchesMask, good_matches)``. ``matchesMask`` is the RANSAC
            inlier mask as a flat list, or None when there were too few
            matches or no mask could be computed.
        """
        if des_input is None or len(des_input) == 0:
            # Nothing to match; knnMatch would raise on missing descriptors.
            return (None, [])

        _, kp_logo, des_logo = self._templates[key]
        matches = self.flann.knnMatch(des_input, des_logo, k=2)

        # Ratio test. Entries with fewer than two neighbours cannot be tested
        # and are ignored.
        good = [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < self.RATIO_TEST * pair[1].distance]

        required = max(min_match_count + 1, self.MIN_HOMOGRAPHY_POINTS)
        if len(good) < required:
            return (None, good)

        src_pts = np.float32([ kp_input[m.queryIdx].pt for m in good ]).reshape(-1,1,2)
        dst_pts = np.float32([ kp_logo[m.trainIdx].pt for m in good ]).reshape(-1,1,2)

        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC,5.0)
        if mask is None:
            return (None, good)
        matches_mask = mask.ravel().tolist()

        if self.fnCalcMSE(src_pts, dst_pts) < max_mse:
            msg_image = UInt8()
            msg_image.data = self.Logos[key].value
            self.pub_logo.publish(msg_image)

        return (matches_mask, good)

    def detect_chess_flag(self, kp_input, des_input):
        """Detect the chess flag; uses its own, lower thresholds."""
        return self._match_logo('chess', kp_input, des_input, 2, 35000)

    def detect_cocacola(self, kp_input, des_input, cv_image_input):
        """Detect the cocacola logo (``cv_image_input`` is unused)."""
        return self._match_logo('cocacola', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def detect_rolex(self, kp_input, des_input, cv_image_input):
        """Detect the rolex logo (``cv_image_input`` is unused)."""
        return self._match_logo('rolex', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def detect_pirelli(self, kp_input, des_input, cv_image_input):
        """Detect the pirelli logo (``cv_image_input`` is unused)."""
        return self._match_logo('pirelli', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def detect_seiko(self, kp_input, des_input):
        """Detect the seiko logo."""
        return self._match_logo('seiko', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def detect_heineken(self, kp_input, des_input):
        """Detect the heineken logo."""
        return self._match_logo('heineken', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def detect_martini(self, kp_input, des_input):
        """Detect the martini logo."""
        return self._match_logo('martini', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def detect_redbull(self, kp_input, des_input):
        """Detect the redbull logo."""
        return self._match_logo('redbull', kp_input, des_input, MIN_MATCH_COUNT, MIN_MSE_DECISION)

    def main(self):
        """Block until the node is shut down."""
        rospy.spin()


if __name__ == '__main__':
    # The node has to be registered before DetectLogo creates its
    # publishers and its subscriber; main() then blocks in rospy.spin().
    rospy.init_node('detect_logo')
    DetectLogo().main()

